#!/usr/bin/env python3

import csv
import os
import re
import sys

import pandas

# Root directory of the search: os.walk() below starts here and looks
# recursively for rsquared.tsv files.  "." is the current working directory.
parent = "."


def get_country(file):
   """Return the country code found in the path string *file*.

   A code is recognised when it appears immediately before a literal '+'
   anywhere in the string.  Codes are tried in the order us, de, il, co
   and the first hit wins; 'no country match' is returned when none hit.
   """
   # (pattern, code) pairs, checked top to bottom
   candidates = (
      (r'(.*)us\+(.*)', 'us'),
      (r'(.*)de\+(.*)', 'de'),
      (r'(.*)il\+(.*)', 'il'),
      (r'(.*)co\+(.*)', 'co'),
   )
   for pattern, code in candidates:
      if re.match(pattern, file):
         return code
   return 'no country match'

def get_disease(file):
   """Return the disease name found in the path string *file*.

   The names are looked for as plain substrings, in a fixed order.
   'influenza-noise' is tested before 'influenza' so the longer name is
   not shadowed by its prefix.  Returns 'no disease match' when none of
   the known names occurs.
   """
   # (pattern, name) pairs; order matters (see docstring)
   candidates = (
      (r'(.*)influenza-noise(.*)', 'influenza-noise'),
      (r'(.*)influenza(.*)', 'influenza'),
      (r'(.*)chlamydia(.*)', 'chlamydia'),
      (r'(.*)dengue(.*)', 'dengue'),
      (r'(.*)malaria(.*)', 'malaria'),
      (r'(.*)measles(.*)', 'measles'),
      (r'(.*)pertussis(.*)', 'pertussis'),
   )
   hits = (name for pattern, name in candidates if re.match(pattern, file))
   return next(hits, 'no disease match')

def get_distance(file):
   """Return the distance encoded in the path string *file* as an int.

   The distance is the run of digits following the first 'd' that is
   directly followed by a digit (e.g. '.../d3/...' -> 3, '.../d12/...'
   -> 12).  All digits are consumed, mirroring how get_horizon reads
   'h<digits>'; reading a single digit would turn 'd12' into 1.

   Returns the string 'no distance' when no such 'd<digits>' occurs.
   """
   # one search both tests for and locates the marker
   match = re.search(r'd([0-9]+)', file)
   if match is None:
      return('no distance')
   return(int(match.group(1)))

def get_horizon(file):
   """Return the horizon encoded in the path string *file* as an int.

   The horizon is the run of digits following the first 'h' that is
   directly followed by a digit (e.g. '.../h0/...' -> 0).  Returns the
   string 'no horizon' when no such 'h<digits>' is found.
   """
   # guard clause: bail out when the marker is absent
   if not re.match(r'(.*)h[0-9]+(.*)', file):
      return 'no horizon'
   found = re.search(r'h[0-9]+', file)
   # drop the leading 'h' and convert what is left
   digits = found.group()[1:]
   return int(digits)

# step 1: get our list of files to go through
rsquared_files = []

for root, dirs, files in os.walk(parent):
   # Only the explicit distance directories should be searched, so prune
   # any 'dbest' directory in place: os.walk (top-down) then never descends
   # into it.  Comparing root to the bare string 'dbest' is not enough,
   # because the roots yielded for "." look like './dbest', and skipping a
   # root does not stop the walk from visiting its subdirectories.
   dirs[:] = [d for d in dirs if d != 'dbest']
   # also covers the case where the walk is started inside a 'dbest' dir
   if (os.path.basename(os.path.normpath(root)) == 'dbest'):
      continue
   for f in files:
      # we want files named exactly 'rsquared.tsv'; plain string equality
      # avoids the regex '.' wildcard matching other characters
      if (f == 'rsquared.tsv'):
         # keep the full path b/c the files are all identically named
         full_path = os.path.realpath(os.path.join(root, f))
         rsquared_files.append(full_path)

# make the summary table: one row per rsquared.tsv file, holding the file's
# largest value, where it was found, and the metadata parsed from its path
rsquared_table = pandas.DataFrame(columns=[  'file',
                                             'country',
                                             'disease',
                                             'max_rsquared',
                                             'distance',
                                             'horizon',
                                             'training',
                                             'staleness'])

for file in rsquared_files:
   # step 1: open file
   # step 2: designate first row and column as headers & not rsquareds
   my_table = pandas.read_table(file, sep="\t", index_col=0)
   # step 3: find maximum rsquared
   # Use pandas' own max() for the final reduction: it skips NaN just like
   # the idxmax() calls below, whereas the builtin max() returns an
   # order-dependent result when a column maximum is NaN.
   column_maxes = my_table.max()
   row_maxes = my_table.max(axis=1)
   max_rsquared = column_maxes.max()
   # step 4: return max rsquared & associated row & column
   max_col = column_maxes.idxmax() ## column label of the max (stored as 'training')
   max_row = row_maxes.idxmax() ## row label of the max (stored as 'staleness')
   # step 5: get remaining data
   country = get_country(file)
   disease = get_disease(file)
   distance = get_distance(file)
   horizon = get_horizon(file)
   # step 5.1: we only care if it's horizon zero
   #if (get_horizon(file) != 0):
   #   continue
   # step 6: add data into rsquared_summary table
   new_row = [file, country, disease, max_rsquared,
              distance, horizon, max_col, max_row]
   # rows are labelled 1, 2, 3, ... in the order the files are processed
   rsquared_table.loc[len(rsquared_table)+1] = new_row


## The summary table starts out empty; it receives one row per
## (country, disease) pair describing the overall best rsquared for it.
rsquared_summary = pandas.DataFrame(
   columns=[
      'country',
      'disease',
      'max_rsquared',
      'distance',
      'horizon',
      'training',
      'staleness',
   ],
)

def get_maximums(subset):
   """Summarise the row(s) of *subset* holding its largest max_rsquared.

   subset is expected to have a 'max_rsquared' column and the column
   order of rsquared_table (file, country, disease, max_rsquared,
   distance, horizon, training, staleness); the other fields are read by
   position.

   Returns [country, disease, max_rsquared, distance, horizon, training,
   staleness].  When several rows tie for the maximum, each field is the
   largest value of that field among the tied rows.  Returns None when
   subset has no rows.
   """
   if len(subset) == 0:
      return None
   # Series.max() skips NaN; the builtin max() is order-dependent with NaN
   max_rsquared = subset.max_rsquared.max()
   # every row that reaches the maximum (boolean mask instead of the
   # DataFrame.ix indexer, which no longer exists in pandas >= 1.0)
   best_rows = subset[subset.max_rsquared == max_rsquared]
   # positions: 1 country, 2 disease, 4 distance, 5 horizon,
   #            6 training, 7 staleness
   country, disease, distance, horizon, training, staleness = (
      best_rows.iloc[:, position].max() for position in (1, 2, 4, 5, 6, 7)
   )
   new_row = [country, disease, max_rsquared,
              distance, horizon, training, staleness]
   return(new_row)

# Walk the countries in order of first appearance and, inside each country,
# its diseases in order of first appearance; every (country, disease) pair
# contributes one row to the summary.
for country_code in rsquared_table['country'].unique():
   country_rows = rsquared_table[(rsquared_table.country == country_code)]
   if len(country_rows) == 0:
      continue
   for disease_name in country_rows['disease'].unique():
      disease_rows = country_rows[(country_rows.disease == disease_name)]
      best = get_maximums(disease_rows)
      # summary rows are labelled 1, 2, 3, ...
      rsquared_summary.loc[len(rsquared_summary) + 1] = best

# write the summary to standard output as tab-separated text
rsquared_summary.to_csv(sys.stdout, sep="\t")
